import torch
from torch import nn
from transformers import BertModel

from config import Config

class FraudClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    The pretrained encoder is loaded with ``BertModel.from_pretrained`` using
    the checkpoint name stored in ``Config.Train.model_name``.

    Args:
        num_classes: Number of logits produced by the linear head.
            Defaults to 5, the value that used to be hard-coded.
        dropout: Dropout probability applied to the encoder output before
            the linear head. Defaults to 0.05, the value that used to be
            hard-coded.
    """

    def __init__(self, *, num_classes: int = 5, dropout: float = 0.05) -> None:
        super().__init__()

        self.base = BertModel.from_pretrained(Config.Train.model_name)

        # Take the width from the loaded checkpoint instead of assuming 768,
        # so the head also fits checkpoints with a different hidden size.
        base_model_hidden_size = self.base.config.hidden_size
        self.drop = nn.Dropout(p=dropout)
        self.classifier = nn.Linear(base_model_hidden_size, num_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the encoder and classify its second (pooled) output.

        Args:
            input_ids: Token ids forwarded unchanged to the base model.
            attention_mask: Attention mask forwarded unchanged to the base
                model.

        Returns:
            The output of the linear head; its last dimension equals the
            ``num_classes`` the model was built with. No softmax is applied.
        """
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # Index rather than unpack into exactly two names: the tuple grows
        # past two entries when the checkpoint config enables extra outputs
        # (hidden states / attentions), which would break `_, x = ...`.
        # NOTE(review): element 1 is BertModel's pooled output per the
        # transformers docs — confirm if the base model class ever changes.
        pooled = outputs[1]
        return self.classifier(self.drop(pooled))
